import torch
from torch import nn
import torch.optim as optim
import numpy as np

def get_tensor_features(G, features, flows):
    '''
    Build a float32 feature matrix for the edges of G that appear in flows.

    Rows follow the iteration order of ``G.edges()``, keeping only edges
    that are members of ``flows``; the i-th such edge contributes
    ``features[e]`` as row i.

    Parameters
    ----------
    G : object with an ``edges()`` method (presumably a networkx graph --
        confirm against callers).
    features : dict mapping edge -> 1-D array-like of per-edge features.
        The feature width is read from one arbitrary entry, so all entries
        are assumed to have the same length. An empty dict yields
        zero-width rows.
    flows : container of edges; only ``len(flows)`` and ``e in flows`` are
        used here.

    Returns
    -------
    torch.Tensor (float32) of shape ``(len(flows), n_features)``, on the
    GPU when CUDA is available. If some members of ``flows`` are not edges
    of G, the trailing rows remain zero.
    '''
    use_cuda = torch.cuda.is_available()

    # Peek at a single value for the width instead of copying every key.
    first = next(iter(features.values()), None)
    n_features = 0 if first is None else np.shape(first)[0]
    feat = np.zeros((len(flows), n_features))

    selected = (e for e in G.edges() if e in flows)
    for i, e in enumerate(selected):
        feat[i] = features[e]

    device = 'cuda' if use_cuda else 'cpu'
    return torch.as_tensor(feat, dtype=torch.float32, device=device)
    
def get_tensor_flows(G, flows):
    '''
    Pack the values of ``flows`` into a 1-D float32 tensor ordered by G.

    Edges are visited in ``G.edges()`` order and only those present in
    ``flows`` contribute, one entry each. The result has ``len(flows)``
    entries; if some keys of ``flows`` are not edges of G, the trailing
    entries stay zero. The tensor lives on the GPU when CUDA is available.

    Parameters
    ----------
    G : object with an ``edges()`` method (presumably networkx -- confirm).
    flows : dict mapping edge -> numeric flow value.
    '''
    make_tensor = torch.cuda.FloatTensor if torch.cuda.is_available() else torch.FloatTensor

    values = np.zeros(len(flows))
    matched_edges = (edge for edge in G.edges() if edge in flows)
    for slot, edge in enumerate(matched_edges):
        values[slot] = flows[edge]

    return make_tensor(values)
    
def get_dict_flows(G, tf, edges):
    '''
    Map rows of a tensor back onto graph edges.

    The edges of ``G.edges()`` that satisfy ``e in edges`` are taken in
    iteration order; the k-th such edge receives ``tf[k, 0].item()``.

    Parameters
    ----------
    G : object with an ``edges()`` method (presumably networkx -- confirm).
    tf : tensor indexed as ``tf[k, 0]`` (at least 2-D); presumably the
        (n, 1) output of ``Net.forward`` -- confirm with callers.
    edges : container tested with ``in`` to select which edges to emit.

    Returns
    -------
    dict mapping edge -> Python float.
    '''
    chosen = [edge for edge in G.edges() if edge in edges]
    return {edge: tf[row, 0].item() for row, edge in enumerate(chosen)}
    
class Net(nn.Module):
    '''
        Simple MLP with ReLU activation in the hidden layer
        and (by default) sigmoid activation in the output layer.

        The network owns its SGD optimizer and a full-batch training loop.
        ``train`` is overloaded: called with data it fits the model; called
        with no arguments or a single bool it behaves like
        ``nn.Module.train(mode)``, so ``eval()`` keeps working.
    '''
    def __init__(self, n_features, n_hidden, n_iter, lr, early_stop=10, output_activation=torch.sigmoid):
        '''
        Parameters
        ----------
        n_features : int
            Input width of the first linear layer.
        n_hidden : int
            Width of the hidden layer.
        n_iter : int
            Maximum number of training epochs.
        lr : float
            SGD learning rate.
        early_stop : int, optional
            Window for early stopping: training stops once the latest
            validation loss exceeds the mean of the previous ``early_stop``
            validation losses. ``0`` or ``None`` disables early stopping.
        output_activation : callable, optional
            Applied to the output layer (``torch.sigmoid`` by default).
        '''
        super(Net, self).__init__()
        self.layer1 = nn.Linear(n_features, n_hidden)
        self.layer2 = nn.Linear(n_hidden, 1)
        self.optimizer = optim.SGD(self.parameters(), lr=lr)
        self.n_iter = n_iter
        self.early_stop = early_stop
        self.output_activation = output_activation

        self.use_cuda = torch.cuda.is_available()

        if self.use_cuda:
            self.cuda()

    def forward(self, x):
        '''Return ``output_activation(layer2(relu(layer1(x))))``; last dim is 1.'''
        x = torch.relu(self.layer1(x))
        x = self.output_activation(self.layer2(x))

        return x

    def train(self, xs_train=True, ys_train=None, xs_valid=None, ys_valid=None, verbose=False):
        '''
        Fit the network with full-batch SGD on an MSE loss, or switch mode.

        Parameters
        ----------
        xs_train, ys_train : torch.Tensor
            Training inputs (presumably shape (n_samples, n_features)) and
            targets matching ``forward(xs_train)[..., 0]``.
        xs_valid, ys_valid : torch.Tensor
            Validation inputs/targets, used for logging and early stopping.
        verbose : bool
            Print losses every 100 epochs and announce early stopping.

        When called as ``train()`` or ``train(mode)`` with a bool (which is
        what ``nn.Module.eval()`` does internally), the call is forwarded to
        ``nn.Module.train`` and ``self`` is returned. The fitting path
        returns None.
        '''
        # This method shadows nn.Module.train(mode). Route mode switches
        # (eval() -> train(False), model.train()) to the base implementation
        # so they do not crash on the missing data arguments.
        if ys_train is None and isinstance(xs_train, bool):
            return super(Net, self).train(xs_train)

        loss_func = nn.MSELoss()
        patience = self.early_stop
        valid_losses = []
        for epoch in range(self.n_iter):
            self.optimizer.zero_grad()
            # Drop the trailing size-1 output dim so shapes match the targets.
            outputs_train = self.forward(xs_train)[..., 0]
            train_loss = loss_func(outputs_train, ys_train)
            # Validation only monitors progress; no autograd graph needed.
            with torch.no_grad():
                outputs_valid = self.forward(xs_valid)[..., 0]
                valid_loss = loss_func(outputs_valid, ys_valid)
            train_loss.backward()
            self.optimizer.step()

            valid_losses.append(valid_loss.item())

            if epoch % 100 == 0 and verbose:
                print("epoch: ", epoch, " train loss = ", train_loss.item(), " valid loss = ", valid_loss.item())

            # Stop when the newest validation loss is worse than the average
            # of the preceding `patience` losses (disabled if patience <= 0).
            if (patience and patience > 0 and epoch > patience
                    and valid_losses[-1] > np.mean(valid_losses[-(patience + 1):-1])):
                if verbose:
                    print("Early stopping...")
                break
    
